from functools import partial

import pytest

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.test_util import check_grads

# Seed the random generator once at import time so the random test inputs
# drawn below are the same on every run.
npr.seed(1)

### fwd mode not yet implemented
# Rebind check_grads so that every call in this module passes modes=["rev"].
check_grads = partial(check_grads, modes=["rev"])


def test_fft():
    """Gradient-check ``np.fft.fft`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)

    def transform(x):
        return np.fft.fft(x)

    check_grads(transform)(sample)


def test_fft_ortho():
    """Gradient-check ``np.fft.fft`` with ``norm="ortho"`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)

    def transform(x):
        return np.fft.fft(x, norm="ortho")

    check_grads(transform)(sample)


def test_fft_axis():
    """Gradient-check ``np.fft.fft`` taken along ``axis=0`` of a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)

    def transform(x):
        return np.fft.fft(x, axis=0)

    check_grads(transform)(sample)


def match_complex(fft_fun, mat):
    """Prepare ``mat`` as an input for ``fft_fun``.

    If the name of ``fft_fun`` starts with ``"ir"``, ``mat`` is passed through
    the ``np.fft`` function whose name is the same minus the leading ``"i"``
    (e.g. ``irfft`` -> ``rfft``) and that result is returned. For any other
    function ``mat`` is returned unchanged.
    """
    name = fft_fun.__name__
    if not name.startswith("ir"):
        return mat

    # ensure hermitian by doing a fft
    forward = getattr(np.fft, name[1:])
    return forward(mat)


def check_fft_n(fft_fun, D, n):
    """Gradient-check ``fft_fun`` called with ``D + n`` as its second argument.

    The input is a random ``D`` x ``D`` matrix, first passed through
    ``match_complex`` so that the ``ir*`` transforms get a suitable input.
    """
    sample = match_complex(fft_fun, npr.randn(D, D))

    def transform(x):
        return fft_fun(x, D + n)

    check_grads(transform)(sample)


def test_fft_n_smaller():
    """Run ``check_fft_n`` for ``np.fft.fft`` with D=5 and n=-2 (second argument 3)."""
    check_fft_n(fft_fun=np.fft.fft, D=5, n=-2)


def test_fft_n_bigger():
    """Run ``check_fft_n`` for ``np.fft.fft`` with D=5 and n=2 (second argument 7)."""
    check_fft_n(fft_fun=np.fft.fft, D=5, n=2)


def test_ifft_n_smaller():
    """Run ``check_fft_n`` for ``np.fft.ifft`` with D=5 and n=-2 (second argument 3)."""
    check_fft_n(fft_fun=np.fft.ifft, D=5, n=-2)


def test_ifft_n_bigger():
    """Run ``check_fft_n`` for ``np.fft.ifft`` with D=5 and n=2 (second argument 7)."""
    check_fft_n(fft_fun=np.fft.ifft, D=5, n=2)


def test_rfft_n_smaller():
    """Run ``check_fft_n`` for ``np.fft.rfft`` with D=4 and n=-2 (second argument 2)."""
    check_fft_n(fft_fun=np.fft.rfft, D=4, n=-2)


def test_rfft_n_bigger():
    """Run ``check_fft_n`` for ``np.fft.rfft`` with D=4 and n=2 (second argument 6)."""
    check_fft_n(fft_fun=np.fft.rfft, D=4, n=2)


def test_irfft_n_smaller():
    """Run ``check_fft_n`` for ``np.fft.irfft`` with D=4 and n=-2 (second argument 2)."""
    check_fft_n(fft_fun=np.fft.irfft, D=4, n=-2)


def test_irfft_n_bigger():
    """Run ``check_fft_n`` for ``np.fft.irfft`` with D=4 and n=2 (second argument 6)."""
    check_fft_n(fft_fun=np.fft.irfft, D=4, n=2)


def check_fft_s(fft_fun, D):
    """Gradient-check ``fft_fun`` called with explicit ``s`` and ``axes``.

    The transform is requested with ``s=[D + 2, D - 2]`` over ``axes=[0, 2]``.
    The input is a random ``D`` x ``D`` x ``D`` array scaled by 1/10 and passed
    through ``match_complex`` so that the ``ir*`` transforms get a suitable
    input.
    """
    sizes = [D + 2, D - 2]
    transform_axes = [0, 2]
    sample = match_complex(fft_fun, npr.randn(D, D, D) / 10.0)

    def transform(x):
        return fft_fun(x, s=sizes, axes=transform_axes)

    check_grads(transform)(sample)


def test_fft2_s():
    """Run ``check_fft_s`` for ``np.fft.fft2`` with D=5."""
    check_fft_s(fft_fun=np.fft.fft2, D=5)


def test_ifft2_s():
    """Run ``check_fft_s`` for ``np.fft.ifft2`` with D=5."""
    check_fft_s(fft_fun=np.fft.ifft2, D=5)


def test_fftn_s():
    """Run ``check_fft_s`` for ``np.fft.fftn`` with D=5."""
    check_fft_s(fft_fun=np.fft.fftn, D=5)


def test_ifftn_s():
    """Run ``check_fft_s`` for ``np.fft.ifftn`` with D=5."""
    check_fft_s(fft_fun=np.fft.ifftn, D=5)


def test_rfft2_s():
    """Run ``check_fft_s`` for ``np.fft.rfft2`` with D=4."""
    check_fft_s(fft_fun=np.fft.rfft2, D=4)


def test_irfft2_s():
    """Run ``check_fft_s`` for ``np.fft.irfft2`` with D=4."""
    check_fft_s(fft_fun=np.fft.irfft2, D=4)


def test_rfftn_s():
    """Run ``check_fft_s`` for ``np.fft.rfftn`` with D=4."""
    check_fft_s(fft_fun=np.fft.rfftn, D=4)


def test_irfftn_s():
    """Run ``check_fft_s`` for ``np.fft.irfftn`` with D=4."""
    check_fft_s(fft_fun=np.fft.irfftn, D=4)


## TODO: fft gradient not implemented for repeated axes
# def test_fft_repeated_axis():
#     D = 5
#     for fft_fun in (np.fft.fft2, np.fft.ifft2, np.fft.fftn, np.fft.ifftn):
#         def fun(x): return fft_fun(x, s=s, axes=axes)
#
#         mat = npr.randn(D, D, D) / 10.0
#         s = [D + 2, D - 2]
#         axes = [0, 0]
#
#         check_grads(fun)(mat)


def test_ifft():
    """Gradient-check ``np.fft.ifft`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifft(x))(sample)


def test_fft2():
    """Gradient-check ``np.fft.fft2`` on a random 5x5 matrix scaled by 1/10."""
    size = 5
    sample = npr.randn(size, size) / 10.0
    check_grads(lambda x: np.fft.fft2(x))(sample)


def test_ifft2():
    """Gradient-check ``np.fft.ifft2`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifft2(x))(sample)


def test_fftn():
    """Gradient-check ``np.fft.fftn`` on a random 5x5 matrix scaled by 1/10."""
    size = 5
    sample = npr.randn(size, size) / 10.0
    check_grads(lambda x: np.fft.fftn(x))(sample)


def test_ifftn():
    """Gradient-check ``np.fft.ifftn`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifftn(x))(sample)


def test_rfft():
    """Gradient-check ``np.fft.rfft`` on a random 4x4 matrix scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size) / 10.0

    def transform(x):
        return np.fft.rfft(x)

    check_grads(transform)(sample)


def test_rfft_ortho():
    """Gradient-check ``np.fft.rfft`` with ``norm="ortho"`` on a random 4x4 matrix scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size) / 10.0

    def transform(x):
        return np.fft.rfft(x, norm="ortho")

    check_grads(transform)(sample)


def test_rfft_axes():
    """Gradient-check ``np.fft.rfft`` taken along ``axis=0`` of a random 4x4 matrix scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size) / 10.0

    def transform(x):
        return np.fft.rfft(x, axis=0)

    check_grads(transform)(sample)


def test_irfft():
    """Gradient-check ``np.fft.irfft`` on the ``rfft`` of a random 4x4 matrix scaled by 1/10."""
    size = 4
    # ensure hermitian by doing a fft
    spectrum = np.fft.rfft(npr.randn(size, size) / 10.0)
    check_grads(lambda x: np.fft.irfft(x))(spectrum)


def test_irfft_ortho():
    """Gradient-check ``np.fft.irfft`` with ``norm="ortho"`` on the ``rfft`` of a random 4x4 matrix scaled by 1/10."""
    size = 4
    # ensure hermitian by doing a fft
    spectrum = np.fft.rfft(npr.randn(size, size) / 10.0)
    check_grads(lambda x: np.fft.irfft(x, norm="ortho"))(spectrum)


def test_rfft2():
    """Gradient-check ``np.fft.rfft2`` on a random 4x4 matrix scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size) / 10.0

    def transform(x):
        return np.fft.rfft2(x)

    check_grads(transform)(sample)


def test_irfft2():
    """Gradient-check ``np.fft.irfft2`` on the ``rfft2`` of a random 4x4 matrix scaled by 1/10."""
    size = 4
    # ensure hermitian by doing a fft
    spectrum = np.fft.rfft2(npr.randn(size, size) / 10.0)
    check_grads(lambda x: np.fft.irfft2(x))(spectrum)


def test_rfftn():
    """Gradient-check ``np.fft.rfftn`` on a random 4x4x4 array scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size, size) / 10.0

    def transform(x):
        return np.fft.rfftn(x)

    check_grads(transform)(sample)


def test_rfftn_odd_not_implemented():
    """Expect ``NotImplementedError`` when gradient-checking ``np.fft.rfftn`` on a random 5x5x5 array."""
    size = 5
    sample = npr.randn(size, size, size) / 10.0

    def transform(x):
        return np.fft.rfftn(x)

    # Keep the whole check inside the context manager, as the test expects
    # the gradient check itself to raise.
    with pytest.raises(NotImplementedError):
        check_grads(transform)(sample)


def test_rfftn_subset():
    """Gradient-check an indexed subset of the ``np.fft.rfftn`` output for a random 4x4x4 array."""
    size = 4
    first_idx = (0, 1, 0)
    second_idx = (3, 3, 2)
    sample = npr.randn(size, size, size) / 10.0

    def transform(x):
        return np.fft.rfftn(x)[first_idx, second_idx]

    check_grads(transform)(sample)


def test_rfftn_axes():
    """Gradient-check ``np.fft.rfftn`` over ``axes=(0, 2)`` of a random 4x4x4 array scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size, size) / 10.0

    def transform(x):
        return np.fft.rfftn(x, axes=(0, 2))

    check_grads(transform)(sample)


def test_irfftn():
    """Gradient-check ``np.fft.irfftn`` on the ``rfftn`` of a random 4x4x4 array scaled by 1/10."""
    size = 4
    # ensure hermitian by doing a fft
    spectrum = np.fft.rfftn(npr.randn(size, size, size) / 10.0)
    check_grads(lambda x: np.fft.irfftn(x))(spectrum)


def test_irfftn_subset():
    """Gradient-check an indexed subset of the ``np.fft.irfftn`` output.

    The input is the ``rfftn`` of a random 4x4x4 array scaled by 1/10.
    """
    size = 4
    first_idx = (0, 1, 0)
    second_idx = (3, 3, 2)
    # ensure hermitian by doing a fft
    spectrum = np.fft.rfftn(npr.randn(size, size, size) / 10.0)

    def transform(x):
        return np.fft.irfftn(x)[first_idx, second_idx]

    check_grads(transform)(spectrum)


def test_fftshift():
    """Gradient-check ``np.fft.fftshift`` on a random 5x5 matrix scaled by 1/10."""
    size = 5
    sample = npr.randn(size, size) / 10.0
    check_grads(lambda x: np.fft.fftshift(x))(sample)


def test_fftshift_even():
    """Gradient-check ``np.fft.fftshift`` on a random 4x4 (even-sized) matrix scaled by 1/10."""
    size = 4
    sample = npr.randn(size, size) / 10.0
    check_grads(lambda x: np.fft.fftshift(x))(sample)


def test_fftshift_axes():
    """Gradient-check ``np.fft.fftshift`` with ``axes=1`` on a random 5x5 matrix scaled by 1/10."""
    size = 5
    sample = npr.randn(size, size) / 10.0
    check_grads(lambda x: np.fft.fftshift(x, axes=1))(sample)


def test_ifftshift():
    """Gradient-check ``np.fft.ifftshift`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifftshift(x))(sample)


def test_ifftshift_even():
    """Gradient-check ``np.fft.ifftshift`` on a random 4x4 (even-sized) matrix."""
    size = 4
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifftshift(x))(sample)


def test_ifftshift_axes():
    """Gradient-check ``np.fft.ifftshift`` with ``axes=1`` on a random 5x5 matrix."""
    size = 5
    sample = npr.randn(size, size)
    check_grads(lambda x: np.fft.ifftshift(x, axes=1))(sample)
